Tank classifier models:

Download images for training

tank_types = ('merkava mk4', 'M1 Abrams', 'water')
path = Path('tanks')

# download up to 150 images per class into labeled directories
if not path.exists():
    path.mkdir()
    for o in tank_types:
        dest = (path/o)
        dest.mkdir(exist_ok=True)
        urls = search_images_ddg(f'{o} tank', max_images=150)
        download_images(dest, urls=urls)

# delete images that failed verification (corrupt or unreadable files)
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink);
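To confirm the cleanup worked, a quick per-class file count helps. This `count_images` helper is a hypothetical addition (not fastai API), demonstrated on a throwaway directory tree that mimics the `tanks/<class>/` layout rather than the real download folder:

```python
from pathlib import Path
import tempfile

def count_images(root):
    """Count image files in each immediate subdirectory of root."""
    exts = {'.jpg', '.jpeg', '.png'}
    return {d.name: sum(1 for f in d.iterdir() if f.suffix.lower() in exts)
            for d in root.iterdir() if d.is_dir()}

# build a throwaway tree mimicking tanks/<class>/*.jpg
tmp = Path(tempfile.mkdtemp())
for cls, n in [('merkava mk4', 3), ('M1 Abrams', 2), ('water', 1)]:
    d = tmp / cls
    d.mkdir()
    for i in range(n):
        (d / f'{i}.jpg').touch()

counts = count_images(tmp)
print(counts)
```

On the real dataset, each class should report close to 150 minus the files removed by `verify_images`.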

Building a model using multi-class loss

Preparing the data for the model

tanks = DataBlock(
    blocks=(ImageBlock, CategoryBlock), 
    get_items=get_image_files, 
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())

dls = tanks.dataloaders(path)

dls.valid.show_batch(max_n=5, nrows=1)
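`RandomSplitter(valid_pct=0.2, seed=42)` holds out a random 20% of the items for validation. Here is a pure-Python approximation of that split (fastai actually uses a torch permutation internally, so the exact indices differ, but the sizes and the seeded determinism are the same idea):

```python
import random

def random_split(n_items, valid_pct=0.2, seed=42):
    """Shuffle indices with a fixed seed, cut off valid_pct for validation."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    n_valid = int(n_items * valid_pct)
    return idxs[n_valid:], idxs[:n_valid]  # (train, valid)

# e.g. ~450 images (3 classes x 150 downloads) -> 360 train / 90 valid
train, valid = random_split(450)
print(len(train), len(valid))  # 360 90
```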

Training the model and reviewing the results

learnMC = cnn_learner(dls, resnet18, metrics=error_rate)
learnMC.fine_tune(4) # fastai picks cross-entropy (multi-class) loss automatically because of the CategoryBlock
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
epoch train_loss valid_loss error_rate time
0 1.476966 0.376781 0.144231 00:45
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
epoch train_loss valid_loss error_rate time
0 0.504485 0.251117 0.125000 00:24
1 0.339178 0.157399 0.057692 00:26
2 0.282746 0.184535 0.067308 00:27
3 0.231388 0.182214 0.067308 00:25
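Because CategoryBlock marks the target as a single class, fastai selects CrossEntropyLoss: a softmax over the class logits followed by the negative log-likelihood of the correct class. A minimal pure-Python sketch of that computation (illustrative only, with made-up logits for our three classes):

```python
import math

def softmax(logits):
    exps = [math.exp(z - max(logits)) for z in logits]  # subtract max for numerical stability
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(logits, target):
    """Negative log-probability assigned to the correct class index."""
    return -math.log(softmax(logits)[target])

# three classes, e.g. ['M1 Abrams', 'merkava mk4', 'water']
logits = [2.0, 0.5, -1.0]
probs = softmax(logits)
print([round(p, 3) for p in probs])        # probabilities sum to 1
print(round(cross_entropy(logits, 0), 3))  # low loss: confident and correct
print(round(cross_entropy(logits, 2), 3))  # high loss: confident and wrong
```

The softmax forces the probabilities to compete: pushing one class up pushes the others down, which is exactly the single-answer behavior discussed later.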

Review top 10 loss images:

Building a model using multi-label loss

Preparing the data for the model

def parentlabel(x):
  return [x.parent.name]  # get_y must return a list of labels for MultiCategoryBlock
tanks2 = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock), 
    #MultiCategoryBlock(add_na=True)
    get_items=get_image_files, 
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parentlabel,
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())

dls2 = tanks2.dataloaders(path)

dls2.valid.show_batch(nrows=1, ncols=5)
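MultiCategoryBlock encodes each item's label list as a multi-hot 0/1 vector over the vocabulary, which is why `get_y` has to return a list. A sketch of that encoding (vocab order assumed alphabetical, as fastai sorts it):

```python
def multi_hot(labels, vocab):
    """Encode a list of labels as a 0/1 vector over the vocab,
    mirroring what MultiCategoryBlock does internally."""
    return [1.0 if v in labels else 0.0 for v in vocab]

vocab = ['M1 Abrams', 'merkava mk4', 'water']
print(multi_hot(['merkava mk4'], vocab))         # [0.0, 1.0, 0.0]
print(multi_hot(['M1 Abrams', 'water'], vocab))  # [1.0, 0.0, 1.0]
```

Unlike the single-integer targets of CategoryBlock, these vectors allow any number of classes per image, including zero.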

Training the model and reviewing the results

learnML = cnn_learner(dls2, resnet18, metrics=accuracy_multi)
# accuracy_multi defaults: thresh=0.5, sigmoid=True
learnML.fine_tune(4)
epoch train_loss valid_loss accuracy_multi time
0 0.903184 0.472175 0.814103 00:24
epoch train_loss valid_loss accuracy_multi time
0 0.454756 0.282619 0.868590 00:25
1 0.372372 0.182789 0.939103 00:33
2 0.299630 0.145639 0.948718 00:30
3 0.243434 0.141897 0.945513 00:28

Adjusting the threshold (since the accuracy curve is smooth, we should not be worried about overfitting to the validation set):

preds, targs = learnML.get_preds()
xs = torch.linspace(0.05, 0.95, 29)
# get_preds already applies the sigmoid, so pass sigmoid=False here
accs = [accuracy_multi(preds, targs, thresh=i, sigmoid=False) for i in xs]
plt.plot(xs, accs);
learnML.metrics = partial(accuracy_multi, thresh=0.85)
learnML.validate()
(#2) [0.14189694821834564,0.9615384340286255]  # [valid_loss, accuracy_multi at thresh=0.85]
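What the sweep is tuning: accuracy_multi simply thresholds the per-class probabilities and averages element-wise correctness. A pure-Python mirror of it (probabilities assumed already sigmoided, as with get_preds; the numbers below are made up for illustration):

```python
def accuracy_multi_py(probs, targets, thresh=0.5):
    """Per-element accuracy after thresholding probabilities
    (mirrors fastai's accuracy_multi with sigmoid=False)."""
    correct = total = 0
    for p_row, t_row in zip(probs, targets):
        for p, t in zip(p_row, t_row):
            correct += int((p >= thresh) == bool(t))
            total += 1
    return correct / total

probs   = [[0.92, 0.10, 0.03], [0.90, 0.55, 0.02]]
targets = [[1, 0, 0], [1, 0, 0]]
print(accuracy_multi_py(probs, targets, thresh=0.5))   # 5/6: false positive at 0.55
print(accuracy_multi_py(probs, targets, thresh=0.85))  # 6/6: stricter threshold removes it
```

Raising the threshold trades false positives for false negatives, which is why the curve above peaks somewhere in the middle rather than at either end.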

Review top 10 loss images:

idx | target      | predicted             | probabilities (M1 Abrams, merkava mk4, water)    | loss
0   | merkava mk4 | M1 Abrams             | TensorBase([0.9994, 0.0057, 0.0015])             | 4.218245983123779
1   | M1 Abrams   | merkava mk4           | TensorBase([0.3920, 0.9105, 0.0013])             | 1.1169370412826538
2   | M1 Abrams   | M1 Abrams;merkava mk4 | TensorBase([0.8152, 0.9213, 0.0116])             | 0.9194316267967224
3   | merkava mk4 | M1 Abrams;merkava mk4 | TensorBase([9.2422e-01, 9.5859e-01, 8.6757e-04]) | 0.874360978603363
4   | merkava mk4 | M1 Abrams             | TensorBase([0.7902, 0.4187, 0.0231])             | 0.8185343742370605
5   | merkava mk4 | merkava mk4;water     | TensorBase([0.0296, 0.8759, 0.7703])             | 0.5444540977478027
6   | M1 Abrams   | M1 Abrams;merkava mk4 | TensorBase([0.6383, 0.6916, 0.0058])             | 0.543639063835144
7   | water       | M1 Abrams;water       | TensorBase([7.7097e-01, 6.5688e-04, 9.9986e-01]) | 0.4915701746940613
8   | M1 Abrams   | M1 Abrams;merkava mk4 | TensorBase([0.6539, 0.5428, 0.0217])             | 0.40980264544487
9   | M1 Abrams   | M1 Abrams;merkava mk4 | TensorBase([9.8275e-01, 6.9141e-01, 8.7449e-04]) | 0.398002952337265

To use our model in an application, we can simply treat the predict method as a regular function.
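For the multi-label model, predict returns every class whose probability clears the threshold. A sketch of that decoding step (`decode_multilabel` is a hypothetical helper, not fastai API; vocab order assumed alphabetical):

```python
def decode_multilabel(probs, vocab, thresh=0.5):
    """Turn per-class probabilities into the list of predicted labels,
    as a multi-label learner's predict does internally."""
    return [v for v, p in zip(vocab, probs) if p >= thresh]

vocab = ['M1 Abrams', 'merkava mk4', 'water']
print(decode_multilabel([0.02, 0.03, 0.01], vocab))  # [] -> nothing fires: likely out-of-scope image
print(decode_multilabel([0.91, 0.88, 0.01], vocab))  # both tank classes fire: two tanks in one image
```

The empty-list case is exactly what the multi-class model cannot express, which motivates the comparison below.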

Additional validation (images outside the training scope)

The first model is multi-class: it is forced to pick exactly one of the three classes for any input, and its softmax pushes it toward a single confident choice. The second model is multi-label and scores each class independently, so it may be better at flagging images outside the training classes, or images containing two classes at once.

Now let's see whether either model performs better on out-of-scope images, even though both were trained on the same dataset:

# learner names normalized to learnMC/learnML from above (assumed mapping of the original multi_learner2/multi_learner)
learnMC.predict('/content/DUDU.jpg')
learnMC.predict('/content/TetroWtank.jpg')
learnML.predict('/content/manytanks.jpg')
learnML.predict('/content/merk_abrms2.jpeg')
learnML.predict('/content/merkav_abrams.jpg')
learnML.predict('/content/mek_abs_3.jpg')
 
#img.show()
#learnMC.predict(img)[0]
learnML.recorder.plot_loss()
learnML.show_results()
learnMC.recorder.plot_loss()
x, y = dls.one_batch()
x.shape
torch.Size([64, 3, 224, 224])
y
TensorCategory([1, 1, 0, 2, 0, 2, 0, 1, 2, 1, 1, 0, 2, 1, 0, 2, 2, 0, 2, 0, 2, 2, 1, 2, 2, 0, 0, 0, 2, 1, 2, 2, 0, 1, 1, 0, 1, 0, 2, 2, 2, 1, 0, 0, 2, 0, 2, 2, 2, 1, 1, 2, 2, 2, 0, 2, 0, 2, 2, 1, 1, 2, 2, 0],
       device='cuda:0')
learnMC.metrics[0]
<fastai.learner.AvgMetric at 0x7f2f37007590>
learnMC.loss_func
FlattenedLoss of CrossEntropyLoss()